import pandas as pd
# NOTE: `logging` here is transformers' logging module (it shadows the stdlib name).
# NOTE(review): `pipeline` is not referenced in the visible code.
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score, precision_recall_fscore_support, classification_report
from datasets import load_dataset, DatasetDict, Dataset, disable_progress_bars
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
# NOTE(review): `accelerate` is not referenced directly; presumably imported so a
# missing install fails fast (Trainer relies on it) — confirm.
import accelerate
import gc
# Turn off `datasets` progress bars (e.g. during Dataset.map).
disable_progress_bars()
import warnings
# Hide every UserWarning and FutureWarning for the rest of the run.
# NOTE(review): this also hides library deprecation notices (pandas chained-assignment
# FutureWarnings, sklearn's UndefinedMetricWarning, a UserWarning subclass) — confirm intended.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Opt in to pandas' future "no silent downcasting" behaviour
# (presumably requires pandas >= 2.2 — confirm against the pinned version).
pd.set_option('future.no_silent_downcasting', True)
# Silence pandas' SettingWithCopyWarning.
pd.options.mode.chained_assignment = None 
# Only report errors from the transformers logger.
logging.set_verbosity_error()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device(argv=None):
    """Return the requested compute device name.

    Priority: ``--device`` command-line flag > ``DEVICE`` environment variable
    > default ``"cuda"``. Unknown command-line arguments are ignored, so the
    script still accepts whatever other flags it is launched with.

    Args:
        argv: argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        The device name as given (expected: 'cpu', 'mps' or 'cuda'). It is not
        validated here; the caller raises on unknown names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    args, _ = parser.parse_known_args(argv)  # ignore unknown args so script still works normally
    # An empty DEVICE variable is treated as unset instead of as a device name.
    return args.device or os.getenv("DEVICE") or "cuda"
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch
# Map the requested backend onto a torch.device. Unknown names are rejected;
# an accelerator that is requested but unavailable falls back to CPU.
# NOTE(review): `device` is not referenced elsewhere in the visible code; the
# Trainer presumably selects its own device — confirm.
if DEVICE not in ("cpu", "mps", "cuda"):
    raise ValueError(f"Unknown device {DEVICE}")
if DEVICE == "mps" and not torch.backends.mps.is_available():
    device = torch.device("cpu")
elif DEVICE == "cuda" and not torch.cuda.is_available():
    device = torch.device("cpu")
else:
    device = torch.device(DEVICE)
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
# NOTE(review): `seed` is not referenced in the visible code; few_shot receives
# its seed from the loop index instead.
seed = 1
# Label mapping for the NLI models: class id -> label name.
id2label = {0: "entailment", 1: "not_entailment"}
# Derive the inverse map from id2label so the two cannot drift apart
# (e.g. through a misspelled key).
label2id = {label: idx for idx, label in id2label.items()}

def compute_metrics_standard(eval_pred, label_text_alphabetical=None):
    """Compute classification metrics for a Hugging Face ``Trainer``.

    Predicted classes are the argmax of the logits along axis 1. The aggregate
    metrics are printed and returned; a per-class ``classification_report``
    (as a dict) is printed only.

    Args:
        eval_pred: object with ``label_ids`` (gold class ids) and
            ``predictions`` (logits, one row per example), such as the object
            the Trainer passes to ``compute_metrics``. If ``predictions`` is a
            tuple, its first element is used as the logits.
        label_text_alphabetical: label names ordered by class id (index i names
            class i), used as ``target_names`` in the report. ``None`` (the
            default) uses ``list(id2label.values())``, read at call time.

    Returns:
        dict with keys 'MCC', 'f1_weighted', 'f1_macro', 'f1_micro',
        'accuracy_balanced', 'accuracy', 'precision_macro', 'recall_macro',
        'precision_micro', 'recall_micro'.
    """
    # None sentinel instead of a list default evaluated once at import time.
    if label_text_alphabetical is None:
        label_text_alphabetical = list(id2label.values())

    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    # Models that return extra outputs give the Trainer a tuple whose first
    # element is the logits.
    if isinstance(pred_logits, tuple):
        pred_logits = pred_logits[0]
    preds_max = np.argmax(pred_logits, axis=1)

    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, preds_max, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, preds_max, average='micro')

    metrics = {
        'MCC': matthews_corrcoef(labels, preds_max),
        'f1_weighted': f1_score(labels, preds_max, average='weighted'),
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'accuracy_balanced': balanced_accuracy_score(labels, preds_max),
        'accuracy': accuracy_score(labels, preds_max),
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
    }
    print("Aggregate metrics: ", metrics)

    # Class i is named label_text_alphabetical[i]. Listing every id explicitly
    # keeps a report row for each class, even one absent from labels/preds.
    class_ids = list(range(len(label_text_alphabetical)))
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=class_ids,
        target_names=label_text_alphabetical, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
        "\n")

    return metrics

def metrics(df, preds, group_by=None, true_col='entailment'):
    """Score one or more prediction columns of ``df`` against a gold column.

    Args:
        df: DataFrame holding the gold labels and the prediction columns.
        preds: iterable of column names in ``df`` that hold predicted labels.
        group_by: 'dataset' or 'task' to score each group of that column
            separately. Any other value (including None) scores the whole frame.
        true_col: name of the gold-label column (default 'entailment').

    Returns:
        DataFrame with 'MCC', 'Accuracy' and weighted 'F1' columns, indexed by
        'Column'. When grouping, the index also includes the capitalised group
        column name ('Dataset' or 'Task'). An empty ``preds`` gives an empty
        frame.
    """
    grouped = group_by in ('dataset', 'task')
    group_label = group_by.capitalize() if grouped else None

    def get_metrics(y_true, y_pred):
        """Return MCC, accuracy and weighted F1 for one set of predictions."""
        return {
            'MCC': matthews_corrcoef(y_true, y_pred),
            'Accuracy': accuracy_score(y_true, y_pred),
            'F1': f1_score(y_true, y_pred, average='weighted'),
        }

    results = []
    for col in preds:
        if grouped:
            for group_name, group in df.groupby(group_by):
                row = get_metrics(group[true_col], group[col])
                row['Column'] = col
                row[group_label] = group_name
                results.append(row)
        else:
            row = get_metrics(df[true_col], df[col])
            row['Column'] = col
            results.append(row)

    index_cols = ['Column', group_label] if grouped else ['Column']
    # Fixing the columns lets an empty result still reach set_index without a KeyError.
    results_df = pd.DataFrame(results, columns=['MCC', 'Accuracy', 'F1'] + index_cols)
    return results_df.set_index(index_cols if grouped else 'Column')

def _few_shot_training_args(model, seed):
    """Build the TrainingArguments that few_shot uses for checkpoint ``model``."""
    is_deberta = "DeBERTa" in model
    return TrainingArguments(
        output_dir='../few_shot/',
        logging_dir='../few_shot/',
        lr_scheduler_type="linear",
        group_by_length=False,
        # Checkpoints whose name contains 'large' use a smaller learning rate.
        learning_rate=9e-6 if 'large' in model else 3e-5,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=8 if is_deberta else 128,
        gradient_accumulation_steps=1,
        num_train_epochs=5,
        warmup_ratio=0.05,
        weight_decay=0.01,
        # DeBERTa checkpoints use fp16; all others (ModernBERT here) use bf16.
        fp16=is_deberta,
        fp16_full_eval=is_deberta,
        bf16=not is_deberta,
        bf16_full_eval=not is_deberta,
        # No evaluation or checkpointing during training.
        eval_strategy="no",
        seed=seed,
        save_strategy="no",
        dataloader_num_workers=0,
        disable_tqdm=True,
    )


def few_shot(model, shots, data, seed):
    """Few-shot fine-tune a model on a random sample of ``data``.

    Samples ``shots`` rows of ``data`` as the training split and keeps every
    remaining row as a validation split. Both splits are tokenized with the
    module-level ``tokenize_function``, and a fresh model from the module-level
    ``model_init`` is trained on the training split.

    Args:
        model: Hugging Face model id/path. It is used to load the tokenizer and
            to choose hyperparameters: a smaller learning rate when the name
            contains 'large', and fp16 plus a smaller eval batch when it
            contains 'DeBERTa'.
        shots: number of training examples to sample.
        data: DataFrame with 'premise', 'hypothesis' and 'entailment' columns;
            'entailment' becomes the 'label' column.
        seed: seed for both the row sample and the Trainer.

    Returns:
        [trainer, dstok]: the trained Trainer and the tokenized DatasetDict
        with 'train' and 'validation' splits.

    NOTE: the weights come from the module-level ``model_init``, which loads the
    checkpoint named by the global ``modname``. Callers here pass ``modname`` as
    ``model``, so the two agree.
    """
    # Use the arguments actually passed in, rather than the globals
    # `modname` / `df`.
    tokenizer = AutoTokenizer.from_pretrained(model)

    train = data.sample(shots, random_state=seed)
    # Every row not drawn for training becomes validation.
    val = data[~data.index.isin(train.index)]
    ds = DatasetDict({
        'train': Dataset.from_pandas(train, preserve_index=False),
        'validation': Dataset.from_pandas(val, preserve_index=False),
    })
    dstok = ds.map(tokenize_function, batched=True)
    # The Trainer expects the gold column to be called 'label'.
    dstok = dstok.rename_columns({'entailment': 'label'})

    trainer = Trainer(
        model_init=model_init,
        processing_class=tokenizer,
        args=_few_shot_training_args(model, seed),
        train_dataset=dstok['train'],
        eval_dataset=dstok['validation'],
        compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=list(id2label.values())),
    )

    trainer.train()

    return [trainer, dstok]

# Import and format COVID Tweets as NLI pairs: the tweet text is the premise,
# paired with one fixed hypothesis. The label is flipped so that
# non_comp == 1 -> 0 ('entailment') and non_comp == 0 -> 1 ('not_entailment'),
# matching id2label.
df = pd.read_csv('./data/covid_tweets_labeled.csv')
# .copy() so the following column writes target an independent frame, not a
# slice of the CSV frame.
df = df[['text', 'non_comp']].copy()
df['hypothesis'] = 'The author of this tweet does not believe COVID is dangerous.'
df = df.rename(columns={'text': 'premise', 'non_comp': 'entailment'})
# Plain column assignment: a chained `df['entailment'].replace(..., inplace=True)`
# does not update df under pandas Copy-on-Write, so the flip would be silently lost.
df['entailment'] = df['entailment'].replace({0: 1, 1: 0})

# Per-seed scores of DEBATE Base on the Pol_NLI test split.
base_mcc, base_f1, base_acc = [], [], []

###############
## DEBATE Base
###############
modname = "mlburnham/Political_DEBATE_DeBERTa_base_v1.1"
# Module-level tokenizer that tokenize_function reads.
tokenizer = AutoTokenizer.from_pretrained(modname)
def tokenize_function(docs):
    """Tokenize a batch of premise/hypothesis pairs with the module-level tokenizer.

    Used through ``Dataset.map(..., batched=True)``, so ``docs`` is a batch
    holding 'premise' and 'hypothesis' columns. No padding or truncation is
    applied here.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(premises, hypotheses, padding=False, truncation=False)
def model_init():
    """Load the ``modname`` checkpoint as a 2-label sequence classifier.

    Passed to ``Trainer(model_init=...)`` in few_shot.
    ``ignore_mismatched_sizes=True`` allows a classification head whose shape
    does not match 2 labels to be replaced.
    """
    load_kwargs = {"num_labels": 2, "ignore_mismatched_sizes": True}
    return AutoModelForSequenceClassification.from_pretrained(modname, **load_kwargs)

# Load Pol_NLI, tokenize it with the current (DEBATE Base) tokenizer and
# rename the gold column to 'label'. Its 'test' split is scored after every seed.
polnli = load_dataset('mlburnham/Pol_NLI')
nlitok = polnli.map(tokenize_function, batched=True).rename_columns({'entailment': 'label'})

# 20 repetitions: a 25-shot fine-tune on the COVID tweets with seed i, then a
# prediction pass over Pol_NLI test whose metrics are recorded.
for i in range(20):
    print("DEBATE Base Seed {}".format(i))
    debase, debase_data = few_shot(modname, shots=25, data=df, seed=i)
    res = debase.predict(nlitok['test'])
    base_mcc.append(res.metrics['test_MCC'])
    base_f1.append(res.metrics['test_f1_weighted'])
    base_acc.append(res.metrics['test_accuracy'])
    # Drop this seed's trainer before the next one.
    del debase
    gc.collect()
    torch.cuda.empty_cache()

debase_res = pd.DataFrame({'mcc': base_mcc, 'f1': base_f1, 'acc': base_acc}).assign(model='DEBATE Base')

###############
## DEBATE Large
###############
# Per-seed scores of DEBATE Large on the Pol_NLI test split.
large_mcc, large_f1, large_acc = [], [], []

modname = "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"
# Module-level tokenizer that tokenize_function reads.
tokenizer = AutoTokenizer.from_pretrained(modname)
def tokenize_function(docs):
    """Tokenize a batch of premise/hypothesis pairs with the module-level tokenizer.

    Used through ``Dataset.map(..., batched=True)``, so ``docs`` is a batch
    holding 'premise' and 'hypothesis' columns. No padding or truncation is
    applied here.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(premises, hypotheses, padding=False, truncation=False)
def model_init():
    """Load the ``modname`` checkpoint as a 2-label sequence classifier.

    Passed to ``Trainer(model_init=...)`` in few_shot.
    ``ignore_mismatched_sizes=True`` allows a classification head whose shape
    does not match 2 labels to be replaced.
    """
    load_kwargs = {"num_labels": 2, "ignore_mismatched_sizes": True}
    return AutoModelForSequenceClassification.from_pretrained(modname, **load_kwargs)

# Import the Pol NLI dataset and tokenize it with the current (DEBATE Large) tokenizer
polnli = load_dataset('mlburnham/Pol_NLI')
nlitok = polnli.map(tokenize_function, batched = True)
# Rename 'entailment' column to 'label'
nlitok = nlitok.rename_columns({'entailment':'label'})

# 20 repetitions: 25-shot fine-tune with seed i, then score on Pol_NLI test.
for i in range(0,20):
    print("DEBATE Large Seed {}".format(i))
    delarge, delarge_data = few_shot(modname, shots = 25, data = df, seed = i)
    res = delarge.predict(nlitok['test'])
    large_mcc.append(res.metrics['test_MCC'])
    large_f1.append(res.metrics['test_f1_weighted'])
    large_acc.append(res.metrics['test_accuracy'])
    # clear memory: drop this seed's trainer before the next seed
    del delarge
    gc.collect()
    torch.cuda.empty_cache()


delarge_res = pd.DataFrame({'mcc': large_mcc, 'f1': large_f1, 'acc': large_acc})
delarge_res['model'] = 'DEBATE Large'

# clear memory. `delarge` is already unbound here (the loop deletes it at the
# end of every iteration), so deleting it again would raise NameError and abort
# the script; just collect.
gc.collect()
torch.cuda.empty_cache()

####################
## Modern Bert Base
####################
# Per-seed scores of DEBATE ModernBERT Base on the Pol_NLI test split.
mbbase_mcc, mbbase_f1, mbbase_acc = [], [], []

modname = "mlburnham/Political_DEBATE_ModernBERT_base_v1.0"
tokenizer = AutoTokenizer.from_pretrained(modname)
# Label maps handed to model_init; label2id is the inverse of id2label.
id2label = {0: "entailment", 1: "not_entailment"}
label2id = {name: idx for idx, name in id2label.items()}

def tokenize_function(docs):
    """Tokenize a batch of premise/hypothesis pairs with the module-level tokenizer.

    Used through ``Dataset.map(..., batched=True)``, so ``docs`` is a batch
    holding 'premise' and 'hypothesis' columns. No padding or truncation is
    applied here.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(premises, hypotheses, padding=False, truncation=False)
def model_init():
    """Load the ``modname`` checkpoint as a 2-label classifier in bfloat16.

    Passed to ``Trainer(model_init=...)`` in few_shot. The module-level
    ``id2label`` / ``label2id`` maps are attached to the model config.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        modname,
        num_labels=2,
        ignore_mismatched_sizes=True,
        label2id=label2id,
        id2label=id2label,
        torch_dtype=torch.bfloat16,
    )

# Import the Pol NLI dataset and tokenize it with the current (ModernBERT Base) tokenizer
polnli = load_dataset('mlburnham/Pol_NLI')
nlitok = polnli.map(tokenize_function, batched = True)
# Rename 'entailment' column to 'label'
nlitok = nlitok.rename_columns({'entailment':'label'})

# 20 repetitions: 25-shot fine-tune with seed i, then score on Pol_NLI test.
for i in range(0,20):
    print("DEBATE MB Base Seed {}".format(i))
    mbbase, mbbase_data = few_shot(modname, shots = 25, data = df, seed = i)
    res = mbbase.predict(nlitok['test'])
    mbbase_mcc.append(res.metrics['test_MCC'])
    mbbase_f1.append(res.metrics['test_f1_weighted'])
    mbbase_acc.append(res.metrics['test_accuracy'])
    # clear memory: drop this seed's trainer before the next seed
    del mbbase
    gc.collect()
    torch.cuda.empty_cache()


mbbase_res = pd.DataFrame({'mcc': mbbase_mcc, 'f1': mbbase_f1, 'acc': mbbase_acc})
mbbase_res['model'] = 'DEBATE MB Base'

# clear memory. `mbbase` is already unbound here (the loop deletes it at the
# end of every iteration), so deleting it again would raise NameError and abort
# the script; just collect.
gc.collect()
torch.cuda.empty_cache()

####################
## Modern Bert Large
####################
# Per-seed scores of DEBATE ModernBERT Large on the Pol_NLI test split.
mblarge_mcc, mblarge_f1, mblarge_acc = [], [], []

modname = "mlburnham/Political_DEBATE_ModernBERT_large_v1.0"
tokenizer = AutoTokenizer.from_pretrained(modname)
# Label maps handed to model_init; label2id is the inverse of id2label.
id2label = {0: "entailment", 1: "not_entailment"}
label2id = {name: idx for idx, name in id2label.items()}

def tokenize_function(docs):
    """Tokenize a batch of premise/hypothesis pairs with the module-level tokenizer.

    Used through ``Dataset.map(..., batched=True)``, so ``docs`` is a batch
    holding 'premise' and 'hypothesis' columns. No padding or truncation is
    applied here.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(premises, hypotheses, padding=False, truncation=False)
def model_init():
    """Load the ``modname`` checkpoint as a 2-label classifier in bfloat16.

    Passed to ``Trainer(model_init=...)`` in few_shot. The module-level
    ``id2label`` / ``label2id`` maps are attached to the model config.
    """
    return AutoModelForSequenceClassification.from_pretrained(
        modname,
        num_labels=2,
        ignore_mismatched_sizes=True,
        label2id=label2id,
        id2label=id2label,
        torch_dtype=torch.bfloat16,
    )

# Import the Pol NLI dataset and tokenize it with the current (ModernBERT Large) tokenizer
polnli = load_dataset('mlburnham/Pol_NLI')
nlitok = polnli.map(tokenize_function, batched = True)
# Rename 'entailment' column to 'label'
nlitok = nlitok.rename_columns({'entailment':'label'})

# 20 repetitions: 25-shot fine-tune with seed i, then score on Pol_NLI test.
for i in range(0,20):
    print("DEBATE MB Large Seed {}".format(i))
    mblarge, mblarge_data = few_shot(modname, shots = 25, data = df, seed = i)
    res = mblarge.predict(nlitok['test'])
    mblarge_mcc.append(res.metrics['test_MCC'])
    mblarge_f1.append(res.metrics['test_f1_weighted'])
    mblarge_acc.append(res.metrics['test_accuracy'])
    # clear memory: drop this seed's trainer before the next seed
    del mblarge
    gc.collect()
    torch.cuda.empty_cache()


mblarge_res = pd.DataFrame({'mcc': mblarge_mcc, 'f1': mblarge_f1, 'acc': mblarge_acc})
mblarge_res['model'] = 'DEBATE MB Large'

# clear memory. `mblarge` is already unbound here (the loop deletes it at the
# end of every iteration), so deleting it again would raise NameError before
# the results are exported; just collect.
gc.collect()
torch.cuda.empty_cache()

#####################
## Compile and Export
#####################
# Stack the per-seed results of all four checkpoints (each row tagged by its
# 'model' column) and write them out without the pandas index.
per_model_results = [debase_res, delarge_res, mbbase_res, mblarge_res]
df = pd.concat(per_model_results)
df.to_csv('./data/fewshot_overfit_res.csv', index=False)